﻿using Microsoft.Extensions.Logging;
using System.Net;
using System.Net.Sockets;

namespace Away.NetMap.Nmap;

/// <summary>
/// Input parameters for scanning a port range on a single host.
/// </summary>
public class PortScannerParams
{
    /// <summary>
    /// Address of the host to scan.
    /// </summary>
    public string IP { get; set; } = default!;

    /// <summary>
    /// First port of the range (inclusive).
    /// </summary>
    public int Start { get; set; } = 1;

    /// <summary>
    /// Last port of the range (inclusive).
    /// </summary>
    public int End { get; set; } = 65535;

    /// <summary>
    /// Thread count: how many ports are probed concurrently in each batch.
    /// </summary>
    public int Threads { get; set; } = 100;

    /// <summary>
    /// Timeout for each individual probe, in milliseconds.
    /// </summary>
    public int Timeout { get; set; } = 300;
}

/// <summary>
/// Scans a port range on a single host, probing each port over TCP and then UDP,
/// and records every open port through <c>PortRepository.Update</c>.
/// </summary>
[ServiceInject(ServiceLifetime.Scoped, true)]
public class PortScannerService
{
    // Private DI scope from which the repository and logger are resolved.
    // NOTE(review): this scope is never disposed. It may be intentional (letting the scanner keep
    // working after the scope that resolved it ends, e.g. for background scans) — confirm the
    // intended lifetime before adding IDisposable.
    private readonly IServiceScope _scope;

    /// <summary>
    /// Creates the service together with its own dependency-injection scope.
    /// </summary>
    /// <param name="serviceProvider">Provider used to create the private scope.</param>
    public PortScannerService(IServiceProvider serviceProvider)
    {
        _scope = serviceProvider.CreateScope();
    }

    private IServiceProvider ServiceProvider => _scope.ServiceProvider;
    private PortRepository Rep => ServiceProvider.GetRequiredService<PortRepository>();
    private ILogger<PortScannerService> Logger => ServiceProvider.GetRequiredService<ILogger<PortScannerService>>();

    // Parameters of the most recent Run call; read by ScanOne on the PLINQ worker threads.
    private PortScannerParams? _params;

    /// <summary>
    /// Scans ports <c>p.Start</c>..<c>p.End</c> (inclusive) on <c>p.IP</c> in consecutive batches of
    /// at most <c>p.Threads</c> ports, each batch probed in parallel.
    /// </summary>
    /// <param name="p">Scan parameters; also stored for use by <see cref="ScanOne"/>.</param>
    /// <param name="token">Checked before every batch and passed to PLINQ for in-flight batches.</param>
    /// <exception cref="ArgumentNullException"><paramref name="p"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <c>p.Threads</c> is outside the range PLINQ accepts for the degree of parallelism.
    /// </exception>
    /// <exception cref="OperationCanceledException">
    /// <paramref name="token"/> is cancelled while a batch is running.
    /// </exception>
    public void Run(PortScannerParams p, CancellationToken token)
    {
        ArgumentNullException.ThrowIfNull(p);
        var stopwatch = System.Diagnostics.Stopwatch.StartNew();

        _params = p;
        var total = p.End - p.Start + 1;
        Logger.LogInformation("开始扫描端口区间：{IP} [{Start},{End}] 共{Total}个", p.IP, p.Start, p.End, total);

        var threads = p.Threads;
        for (var port = p.Start; port <= p.End; port += threads)
        {
            if (token.IsCancellationRequested)
            {
                break;
            }

            // Clamp the final batch so no port beyond End is ever scanned.
            var count = Math.Min(threads, p.End - port + 1);

            Enumerable.Range(port, count)
                .AsParallel()
                .WithDegreeOfParallelism(threads)
                .WithCancellation(token)
                .ForAll(ScanOne);
        }

        stopwatch.Stop();
        Logger.LogInformation("结束扫描端口区间：{IP} [{Start},{End}] 共{Total}个 耗时：{Elapsed}", p.IP, p.Start, p.End, total, stopwatch.Elapsed);
    }

    /// <summary>
    /// Probes a single port of the host given to the last <see cref="Run"/> call: TCP first,
    /// UDP only if TCP fails, and records the outcome.
    /// </summary>
    /// <param name="port">Port to probe.</param>
    /// <exception cref="InvalidOperationException"><see cref="Run"/> has not been called yet.</exception>
    public void ScanOne(int port)
    {
        var p = _params ?? throw new InvalidOperationException("Run must be called before ScanOne.");
        var ip = p.IP;
        var timeout = p.Timeout;

        string? protocol = TcpScan(ip, port, timeout) ? "TCP"
            : UdpScan(ip, port, timeout) ? "UDP"
            : null;

        Completed(new ScanCompletedEventArgs
        {
            IP = ip,
            Port = port,
            Success = protocol is not null,
            Protocol = protocol,
        });
    }

    /// <summary>
    /// Tries to open a TCP connection to <paramref name="ip"/>:<paramref name="port"/>.
    /// </summary>
    /// <param name="timeout">Connect timeout in milliseconds; non-positive means no limit.</param>
    /// <returns><c>true</c> if the connection was established within the timeout.</returns>
    private static bool TcpScan(string ip, int port, int timeout)
    {
        try
        {
            using TcpClient tcpClient = new();
            // Send/ReceiveTimeout do not apply to Connect, so the connect itself is bounded with a
            // cancellation token. A non-positive value means "no limit", like socket timeouts.
            using CancellationTokenSource cts = new(timeout > 0 ? timeout : System.Threading.Timeout.Infinite);
            tcpClient.ConnectAsync(ip, port, cts.Token).AsTask().GetAwaiter().GetResult();
            return true;
        }
        catch
        {
            // Refused, unreachable, timed out or invalid endpoint: deliberately treated as "not open".
            return false;
        }
    }

    /// <summary>
    /// Sends a one-byte UDP datagram and waits for any reply.
    /// </summary>
    /// <param name="timeout">Socket send/receive timeout in milliseconds.</param>
    /// <returns><c>true</c> only if a reply was received before the timeout.</returns>
    private static bool UdpScan(string ip, int port, int timeout)
    {
        try
        {
            using UdpClient udpClient = new();
            udpClient.Client.SendTimeout = timeout;
            udpClient.Client.ReceiveTimeout = timeout;
            byte[] probe = new byte[1];
            udpClient.Send(probe, probe.Length, ip, port);
            IPEndPoint? remoteEP = null;
            _ = udpClient.Receive(ref remoteEP);
            return true;
        }
        catch
        {
            // No reply, timeout or an error reported by the socket: deliberately treated as "not open".
            return false;
        }
    }

    // Shared by all instances, so repository writes from every scanner are serialized.
    // NOTE(review): presumably because the backing store does not tolerate concurrent updates — confirm.
    private static readonly object _lock = new();

    /// <summary>
    /// Logs and persists a successful probe; unsuccessful probes are ignored.
    /// </summary>
    private void Completed(ScanCompletedEventArgs args)
    {
        if (!args.Success)
        {
            return;
        }

        lock (_lock)
        {
            Logger.LogTrace("{Protocol}://{IP}:{Port}", args.Protocol, args.IP, args.Port);
            Rep.Update(args.IP, args.Port);
        }
    }
}

/// <summary>
/// Outcome of probing a single port.
/// </summary>
public class ScanCompletedEventArgs : EventArgs
{
    /// <summary>
    /// Host that was probed.
    /// </summary>
    public string IP { get; set; } = default!;

    /// <summary>
    /// Port that was probed.
    /// </summary>
    public int Port { get; set; }

    /// <summary>
    /// Protocol that answered (e.g. "TCP" or "UDP"); null when nothing answered.
    /// </summary>
    public string? Protocol { get; set; }

    /// <summary>
    /// Whether the port responded to a probe.
    /// </summary>
    public bool Success { get; set; }
}